from bluepyopt.ephys.responses import TimeVoltageResponse
import matplotlib as mpl
import json
import numpy as np
from bluepyparallel import init_parallel_factory
import yaml
from functools import partial

from pathlib import Path
import matplotlib.pyplot as plt
import pickle

# from emodel_generalisation.utils import plot_traces
from emodel_generalisation.model.modifiers import synth_soma, synth_axon
from emodel_generalisation.mcmc import load_chains
from emodel_generalisation.mcmc import save_selected_emodels
import logging
from itertools import cycle
import json
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
from emodel_generalisation.model.access_point import AccessPoint
from emodel_generalisation.model.evaluation import feature_evaluation
from emodel_generalisation.utils import get_combo_hash
from emodel_generalisation.utils import get_feature_df
from datareuse import Reuse
import extra_features

# Offset (mV) used to convert a ``shift_it2.somadend`` value into the voltage
# window center: center = -shift - 65.3 (mean value from activation curves).
WINDOW_CENTER_OFFSET = 65.3


def window_center(shift, offset=WINDOW_CENTER_OFFSET):
    """Return the voltage window center [mV] corresponding to a voltage shift."""
    return -shift - offset


def burst_curve(row, columns):
    """Collect (holding_voltage, burst_number) samples for one evaluated model.

    For every column in ``columns`` ending in ``.voltage_base``, the sibling
    column with the last dotted component replaced by ``all_burst_number`` is
    read from ``row`` as the matching burst number.

    Args:
        row: one row of the feature table (a ``pd.Series`` indexed by feature
            column names).
        columns: candidate feature column names to scan.

    Returns:
        A float ``pd.DataFrame`` with columns ``burst_number`` and
        ``holding_voltage``, sorted by increasing ``holding_voltage`` (NaNs
        last) so it can be plotted as a curve and interpolated directly.
    """
    records = []
    for col in columns:
        if not col.endswith(".voltage_base"):
            continue
        burst_col = ".".join(col.split(".")[:-1] + ["all_burst_number"])
        records.append({"burst_number": row[burst_col], "holding_voltage": row[col]})
    curve = pd.DataFrame(records, columns=["burst_number", "holding_voltage"], dtype=float)
    return curve.sort_values("holding_voltage", kind="stable").reset_index(drop=True)


def interp_burst_number(curve, h):
    """Linearly interpolate the burst number of ``curve`` at holding voltages ``h``.

    ``np.interp`` requires strictly usable, increasing sample positions, so
    rows with a NaN in either column are dropped and the rest sorted first.
    Outside the sampled range, ``np.interp`` clamps to the end values.
    Returns an all-NaN array when no valid sample is left.
    """
    valid = curve.dropna().sort_values("holding_voltage", kind="stable")
    if valid.empty:
        return np.full(len(h), np.nan)
    return np.interp(
        h,
        valid["holding_voltage"].to_numpy(dtype=float),
        valid["burst_number"].to_numpy(dtype=float),
    )


if __name__ == "__main__":
    parallel_factory = init_parallel_factory("multiprocessing")
    name = "runaway"
    access_point = AccessPoint(
        emodel_dir=".",
        recipes_path="config/recipes.json",
        final_path=f"final_{name}.json",
        with_seeds=True,
    )
    with open("../../mcmc_run/exemplar_data.yaml") as exemplar_file:
        exemplar_data = yaml.safe_load(exemplar_file)
    access_point.morph_path = exemplar_data["paths"]["all"]
    access_point.settings["morph_modifiers"] = [
        partial(synth_soma, params=exemplar_data["soma"], scale=1.0),
        partial(synth_axon, params=exemplar_data["ais"]["popt"], scale=1.0),
    ]

    emodel = "simplest_0"
    # Scan 7 values of ``shift_it2.somadend`` in 1 mV steps around a reference
    # shift; offset 0 (row index 3) is the reference itself.
    # NOTE(review): the reference presumably is the fitted value stored in
    # final_{name}.json — confirm before changing the emodel.
    reference_shift = 2.959295568009344
    values = np.linspace(-3, 3, 7) + reference_shift
    df = pd.DataFrame(
        [
            {
                "name": f"shift_{i}",
                "emodel": emodel,
                "new_parameters": json.dumps({"shift_it2.somadend": val}),
            }
            for i, val in enumerate(values)
        ]
    )

    print(df)
    with Reuse(f"shift_scan_eval_{name}.csv") as reuse:
        df = reuse(
            feature_evaluation,
            df,
            access_point,
            parallel_factory=parallel_factory,
            trace_data_path="traces",
            # record_ions_and_currents=True,
        )

    d = get_feature_df(df)
    print(d)

    # Keep only features whose protocol name has exactly 4 "_"-separated parts.
    columns = [c for c in d.columns if len(c.split(".")[0].split("_")) == 4]
    cmap = plt.get_cmap("coolwarm")
    colors = cmap(np.linspace(0, 1, len(d)))
    h = np.linspace(-80, -55, 10)  # holding voltages sampled for the heatmap
    heat_rows = {}

    # Burst number vs holding voltage, one curve per voltage shift.
    plt.figure(figsize=(6, 3))
    # NOTE(review): ``gid`` is used both as a label of ``d`` and as a position
    # in ``values``; this assumes d keeps df's 0..6 index — confirm.
    for gid, c in zip(d.index, colors):
        curve = burst_curve(d.loc[gid], columns)
        if gid == 3:  # unshifted reference model (offset 0)
            c = "C2"
        plt.plot(curve["holding_voltage"], curve["burst_number"], marker="+", c=c)
        plt.axvline(window_center(values[gid]), c=c, ls="--")
        heat_rows[values[gid]] = interp_burst_number(curve, h)
    data_df = pd.DataFrame.from_dict(heat_rows, orient="index", columns=h)

    plt.xlabel("holding_voltage")
    plt.ylabel("burst_number")
    norm = mpl.colors.Normalize(vmin=values[0], vmax=values[-1])
    plt.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=plt.gca(),
        orientation="vertical",
        shrink=0.6,
        label="voltage shift",
    )
    plt.tight_layout()
    plt.axis([-80, -50, 0, 35])
    plt.savefig(f"shift_burst_scan_{name}.pdf")
    plt.close()

    # Heatmap of interpolated burst number: x = holding voltage (columns of
    # data_df, i.e. ``h``), y = voltage window center (rows, one per shift).
    fig, ax = plt.subplots(figsize=(5, 5))
    image = ax.imshow(
        data_df,
        origin="upper",
        # Row 0 is the smallest shift, i.e. the highest window center, so with
        # origin="upper" it belongs at the top of the y extent.
        extent=[h[0], h[-1], window_center(values[-1]), window_center(values[0])],
        cmap="Blues",
    )
    fig.colorbar(image, ax=ax, shrink=0.2, label="burst number")
    ax.set_xlabel("holding voltage [mV]")
    ax.set_ylabel("voltage window center [mV]")
    fig.tight_layout()
    fig.savefig("shift_heatmap.pdf")
    plt.close(fig)
